"""
Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
SPDX-License-Identifier: MIT
"""

import argparse
import json
import logging
import os

from pymysql import InternalError
import sys 

from sub_platforms.sql_opt.env.rds_env import OpenMySQLEnv
from sub_platforms.sql_opt.videx import videx_logging
from sub_platforms.sql_opt.videx.videx_metadata import fetch_all_meta_for_videx, \
    construct_videx_task_meta_from_local_files, fetch_all_meta_with_one_file, \
    collect_sample_data_for_tables, save_sample_data_to_files, construct_meta_request_with_samples
from sub_platforms.sql_opt.videx.videx_service import create_videx_env_multi_db, \
    post_add_videx_meta
from sub_platforms.sql_opt.videx.videx_utils import VIDEX_IP_WHITE_LIST


def get_usage_message(args, videx_ip, videx_port, videx_db, videx_user, videx_pwd, videx_server_ip_port):
    """
    Build the human-readable instructions shown after the environment is built.

    The message tells the user how to connect to VIDEX-MySQL and which session
    variables to set before running EXPLAIN.

    Args:
        args: parsed CLI arguments; only `args.task_id` is read. When it is truthy,
            a `SET @VIDEX_OPTIONS=...` line carrying the task id is included;
            otherwise the message states that non-task mode is used.
        videx_ip: host of the VIDEX-MySQL instance.
        videx_port: port of the VIDEX-MySQL instance.
        videx_db: database name inside VIDEX-MySQL.
        videx_user: user name for VIDEX-MySQL.
        videx_pwd: password for VIDEX-MySQL (printed verbatim in the command hints).
        videx_server_ip_port: "ip:port" address of the VIDEX server.

    Returns:
        A multi-line string terminated by a newline.
    """
    # Shared prefix of every `mysql` command line printed below.
    connect_cmd = f"mysql -h{videx_ip} -P{videx_port} -u{videx_user} -p{videx_pwd}"

    lines = [f"Build env finished. Your VIDEX server is {videx_server_ip_port}."]
    if not args.task_id:
        lines.append("You are running in non-task mode.")
    lines += [
        "To use VIDEX, please set the following variable before explaining your SQL:",
        "-" * 20,
        f"-- Connect VIDEX-MySQL: {connect_cmd} -D{videx_db}",
        f"USE {videx_db};",
        f"SET @VIDEX_SERVER='{videx_server_ip_port}'; -- For MySQL ",
    ]
    if args.task_id:
        videx_options = json.dumps({"task_id": args.task_id})
        lines.append(f"SET @VIDEX_OPTIONS='{videx_options}';")
    lines += [
        # Use the configured server address here as well; a hard-coded
        # 127.0.0.1:5001 would be wrong whenever --videx_server is customised.
        f"SET SESSION VIDEX_SERVER_IP='{videx_server_ip_port}'; -- For MariaDB ",
        "-- EXPLAIN YOUR_SQL;",
        "-- Note, if your MySQL version is 5.x, please setup/clear the environment "
        "before and after your connecting as follows:",
        f"{connect_cmd} < setup_mysql57_env.sql",
        f"{connect_cmd} < clear_mysql57_env.sql",
    ]
    return "\n".join(lines) + "\n"


def parse_connection_info(info):
    """
    Parse a connection string of the form "ip:port:db:user:password".

    Only the first four colons are treated as separators, so the password may
    itself contain ':' characters.

    Args:
        info: connection string, e.g. "127.0.0.1:13308:tpch:user:password".

    Returns:
        Tuple `(ip, port, db, user, password)` where `port` is an int and the
        other items are strings.

    Raises:
        ValueError: if fewer than five fields are present or the port is not an integer.
    """
    # maxsplit=4 keeps everything after the fourth ':' as the password.
    fields = info.split(':', 4)
    if len(fields) != 5:
        # Do not echo `info` here: it may contain credentials.
        raise ValueError('Connection info must be in the format "ip:port:db:user:password", '
                         f'got {len(fields)} field(s) instead of 5.')
    ip, port, db, user, pwd = fields
    return ip, int(port), db, user, pwd


if __name__ == "__main__":
    """
    Collect data from `target_ins`, then import and create a Videx environment.

    Specify the connection method for the target MySQL and the target database to set up the Videx environment.
    The process includes three steps:
    1. Collect necessary metadata, schema, and statistics information for Videx from the source MySQL.
        If any files or directories are specified, load directly from them.
    2. In `VIDEX-MySQL`, create database tables corresponding to those in `target-mysql`, but replace the engine with VIDEX.
    3. Import metadata and statistical information into the Videx server.

    Once these preparatory steps are completed, `EXPLAIN SQL` can be executed on `VIDEX-MySQL` to simulate
    the query plan generated by the target MySQL.

    Example:
        python videx_build_env.py --target "127.0.0.1:13308:tpch:user:password" --videx_server 5002

    """
    parser = argparse.ArgumentParser(description='Collect data from target_ins and create videx environment.')
    parser.add_argument('--target', type=str, required=True,
                        help='Connection info for raw instance, in the format of "ip:port:db:user:password"')
    parser.add_argument('--videx', type=str,
                        help='Connection info for videx instance, in the format of "ip:port:db:user:password". '
                             'If not provided, it will be generated based on raw instance info.')
    # videx_server: IP and port information. Videx MySQL will inform Videx Python about this address
    # and send Videx queries to it.
    parser.add_argument('--videx_server', type=str, default="5001",
                        help='Connection info for videx server, in the format of "[ip:]port". '
                             'If not provided, access "{videx_ip}:5001".')
    parser.add_argument('--tables', type=str, default=None,
                        help='Comma-separated list of table names to fetch. If not provided, fetching all tables. '
                             'e.g. customer,nation')
    parser.add_argument('--meta_path', type=str, default=None,
                        help='meta filepath to save pulled metadata.')
    parser.add_argument('--fetch_method', type=str, default='fetch', help='fetch, partial_fetch, sampling')
    # task_id: used to differentiate metadata of different tasks or even different versions of the same task.
    parser.add_argument('--task_id', type=str, default=None,
                        help='task id is to distinguish different videx tasks, if they have same database names.')
    parser.add_argument('--hist_algo', type=str, default=None,
                        help='Histogram algorithm: None (default), block_2phase')

    videx_logging.initial_config()
    args = parser.parse_args()

    # step 1: parse arguments
    target_ip, target_port, target_db, target_user, target_pwd = parse_connection_info(args.target)

    if args.videx:
        videx_ip, videx_port, videx_db, videx_user, videx_pwd = parse_connection_info(args.videx)
    else:
        # if no videx, the videx is located to the same instance with target mysql
        videx_ip, videx_port, videx_user, videx_pwd = target_ip, target_port, target_user, target_pwd
        videx_db = f'videx_{target_db}'

    # Explicit check instead of `assert`, which is stripped when Python runs with -O.
    if target_ip == videx_ip and target_port == videx_port and target_db == videx_db:
        raise ValueError("Since `target_ins` and `videx_ins` are the same instance, "
                         "their `db` properties must not be the same.")

    # `--videx_server` is either a full "ip:port" or a bare port on the VIDEX-MySQL host.
    if ':' in args.videx_server:
        videx_server_ip_port = args.videx_server
    else:
        videx_server_ip_port = f"{videx_ip}:{args.videx_server}"

    target_env = OpenMySQLEnv(ip=target_ip, port=target_port, usr=target_user, pwd=target_pwd, db_name=target_db,
                              read_timeout=300, write_timeout=300, connect_timeout=10)
    try:
        videx_env = OpenMySQLEnv(ip=videx_ip, port=videx_port, usr=videx_user, pwd=videx_pwd, db_name=videx_db,
                                 read_timeout=300, write_timeout=300, connect_timeout=10)
    except Exception as e:
        # The videx database may not exist yet: connect without a default db, create it, then select it.
        # Any other connection failure is re-raised unchanged.
        if f"Unknown database '{videx_db}'" in str(e):
            videx_env = OpenMySQLEnv(ip=videx_ip, port=videx_port, usr=videx_user, pwd=videx_pwd, db_name=None,
                                     read_timeout=300, write_timeout=300, connect_timeout=10)
            videx_env.execute(f"CREATE DATABASE IF NOT EXISTS `{videx_db}`")
            videx_env.set_default_db(videx_db)
        else:
            raise

    if args.tables:
        all_table_names = args.tables.split(',')
    else:
        all_table_names = None  # No restriction, fetch all tables from target database

    if args.meta_path:
        meta_path = args.meta_path
        if os.path.dirname(meta_path):
            os.makedirs(os.path.dirname(meta_path), exist_ok=True)
    else:
        # Load the existing meta file or save it to a file only when the meta_path is explicitly defined.
        meta_path = None
    logging.info(f"metadata file is {meta_path}")

    # step 2: fetch or read metadata and statistics
    if args.fetch_method in ['fetch', 'partial_fetch']:
        # N.B.: Fetching NDV and histogram data can be costly. Ensure the target IP is offline or permitted.
        VIDEX_IP_WHITE_LIST.append(target_ip)
        # TODO fetch ndv and histogram may be costly, if partial_fetch, we skip fetching ndv and histogram
        if args.fetch_method == 'partial_fetch':
            pass  # more tests are required before supporting it; currently behaves exactly like `fetch`
        files = fetch_all_meta_with_one_file(meta_path=meta_path,
                                             env=target_env, target_db=target_db, all_table_names=all_table_names,
                                             n_buckets=16, hist_force=True,
                                             hist_mem_size=200000000, drop_hist_after_fetch=True,
                                             hist_algo=args.hist_algo)
        stats_file_dict, hist_file_dict, ndv_single_file_dict, ndv_mulcol_file_dict = files
        meta_request = construct_videx_task_meta_from_local_files(task_id=args.task_id,
                                                                  videx_db=videx_db,
                                                                  stats_file=stats_file_dict,
                                                                  hist_file=hist_file_dict,
                                                                  ndv_single_file=ndv_single_file_dict,
                                                                  ndv_mulcol_file=ndv_mulcol_file_dict,
                                                                  gt_rec_in_ranges_file=None,
                                                                  gt_req_resp_file=None,
                                                                  raise_error=True,
                                                                  )
    elif args.fetch_method == 'sampling':
        # This method will generate metadata from the sample data.
        # Additionally, the sample data will be employed to estimate the ndv (the number of distinct values) and cardinality.
        logging.info("Using sampling mode to collect metadata")

        if all_table_names is None:
            # No explicit table list: discover every InnoDB table of the target database.
            sql = f"""
                SELECT TABLE_NAME 
                FROM information_schema.TABLES 
                WHERE table_schema = '{target_db}' and ENGINE = 'InnoDB'
            """
            table_df = target_env.query_for_dataframe(sql)
            all_table_names = [str(name).lower() for name in table_df['TABLE_NAME'].tolist()]
            logging.info(f"Retrieved table names: {all_table_names}")

        if not all_table_names:
            logging.error("No tables found for sampling")
            sys.exit(1)

        files = fetch_all_meta_with_one_file(
            meta_path=meta_path,
            env=target_env,
            target_db=target_db,
            all_table_names=all_table_names,
            n_buckets=16,
            hist_force=True,
            hist_mem_size=200000000,
            drop_hist_after_fetch=True
        )
        stats_file_dict, hist_file_dict, ndv_single_file_dict, ndv_mulcol_file_dict = files

        # Collect sampling data
        sample_file_dict = collect_sample_data_for_tables(
            env=target_env,
            target_db=target_db,
            all_table_names=all_table_names,
            sample_size=1000
        )

        # Save the sampled data to a file
        save_dir = f"/tmp/videx_samples/{videx_db}"
        sample_file_info_dict = save_sample_data_to_files(
            sample_file_dict=sample_file_dict,
            videx_db=videx_db,
            save_dir=save_dir
        )

        # Construct the metadata request, including the sampled data
        meta_request = construct_meta_request_with_samples(
            task_id=args.task_id,
            videx_db=videx_db,
            stats_file_dict=stats_file_dict,
            sample_file_dict=sample_file_dict,
            sample_file_info_dict=sample_file_info_dict
        )

    else:
        # List the methods that are actually handled above.
        raise NotImplementedError(f"Fetching method `{args.fetch_method}` not implemented, "
                                  f"only support `fetch`, `partial_fetch`, `sampling`.")

    # step 3: create tables into VIDEX-MySQL, post metadata and statistics to VIDEX-Server
    # Create the tables in VIDEX-MySQL
    create_videx_env_multi_db(videx_env, meta_dict=meta_request.meta_dict, )
    # Import the metadata into VIDEX-Server
    response = post_add_videx_meta(meta_request, videx_server_ip_port=videx_server_ip_port, use_gzip=True)
    # Explicit check instead of `assert`: with -O an assert would be skipped and a failed
    # upload would be reported as success.
    if response.status_code != 200:
        raise RuntimeError(f"Failed to post metadata to VIDEX server {videx_server_ip_port}: "
                           f"HTTP status {response.status_code}")

    logging.info(f"metadata file is {meta_path}")
    logging.info(get_usage_message(args, videx_ip, videx_port, videx_db, videx_user, videx_pwd, videx_server_ip_port))
